{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "03 - Build an Embeddings index from a data source",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WDbhGHtG8jFE",
        "colab_type": "text"
      },
      "source": [
        "# Part 3: Build an Embeddings index from a data source\n",
        "\n",
        "In Part 1, we gave a general overview of txtai, the backing technology, and examples of how to use it for similarity search. Part 2 covered how to use txtai for extractive question-answering systems.\n",
        "\n",
        "The previous examples worked on data stored in memory for demo purposes. For real-world, large-scale use cases, data is usually stored in a database or other data store (Elasticsearch, SQL, MongoDB, files, etc.). This example covers reading data from SQLite, building an Embeddings index backed by word embeddings and running queries against the generated Embeddings index.\n",
        "\n",
        "This example covers functionality found in the [paperai](https://github.com/neuml/paperai) library. See that library for a full solution that can be used with the dataset discussed below."
      ]
    },
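    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The rest of this notebook feeds data to txtai through Python generators that yield one `(id, data, tags)` tuple per document, so rows can be streamed from the database instead of being loaded into memory up front. The cell below is a minimal sketch of that pattern under stated assumptions: it uses an in-memory SQLite table with made-up rows, and `stream_documents` is a hypothetical helper, not part of txtai. The `stream()` function later in this notebook yields the same tuple shape, with tokenized text in the middle position."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code"
      },
      "source": [
        "import sqlite3\n",
        "\n",
        "def stream_documents(connection, sql):\n",
        "  \"\"\"Yield (id, text, tags) tuples one row at a time (hypothetical helper).\"\"\"\n",
        "  cursor = connection.cursor()\n",
        "\n",
        "  # Iterating over the cursor returns rows one by one instead of calling fetchall()\n",
        "  for uid, text in cursor.execute(sql):\n",
        "    yield (uid, text, None)\n",
        "\n",
        "# Toy in-memory table with made-up rows, standing in for articles.sqlite\n",
        "connection = sqlite3.connect(\":memory:\")\n",
        "connection.execute(\"CREATE TABLE sections (Id INTEGER PRIMARY KEY, Text TEXT)\")\n",
        "connection.executemany(\"INSERT INTO sections (Id, Text) VALUES (?, ?)\",\n",
        "                       [(1, \"Diabetes is a risk factor\"), (2, \"Patients were hospitalized\")])\n",
        "\n",
        "documents = list(stream_documents(connection, \"SELECT Id, Text FROM sections ORDER BY Id\"))\n",
        "connection.close()\n",
        "\n",
        "print(documents)"
      ],
      "execution_count": null,
      "outputs": []
    },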
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UQ0fCwXn9bcH",
        "colab_type": "text"
      },
      "source": [
        "# Install dependencies\n",
        "\n",
        "Install txtai and all dependencies"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "czPYSA2Q9ZHO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "%%capture\n",
        "!pip install git+https://github.com/neuml/txtai"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SN9SCZKQ9fJF",
        "colab_type": "text"
      },
      "source": [
        "# Download data\n",
        "\n",
        "This example works with a subset of the [CORD-19](https://www.semanticscholar.org/cord19) dataset. The COVID-19 Open Research Dataset (CORD-19) is a free resource of scholarly articles, aggregated by a coalition of leading research groups, covering COVID-19 and the coronavirus family of viruses.\n",
        "\n",
        "The following cell downloads a SQLite database containing a subset of CORD-19, generated from a [Kaggle notebook](https://www.kaggle.com/davidmezzetti/cord-19-slim/output). More information on this data format can be found in the [CORD-19 Analysis](https://www.kaggle.com/davidmezzetti/cord-19-analysis-with-sentence-embeddings) notebook."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TONQ4_Kv9dtd",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 315
        },
        "outputId": "55f85ebf-946e-44aa-e27e-026436be909e"
      },
      "source": [
        "!wget https://github.com/neuml/txtai/releases/download/v1.1.0/tests.gz\n",
        "!gunzip tests.gz\n",
        "!mv tests articles.sqlite"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "--2020-08-25 01:56:04--  https://github.com/neuml/txtai/releases/download/v1.1.0/tests.gz\n",
            "Resolving github.com (github.com)... 140.82.112.4\n",
            "Connecting to github.com (github.com)|140.82.112.4|:443... connected.\n",
            "HTTP request sent, awaiting response... 302 Found\n",
            "Location: https://github-production-release-asset-2e65be.s3.amazonaws.com/286301447/080d8800-e653-11ea-8d02-c0c858a09e7a?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20200825%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20200825T015604Z&X-Amz-Expires=300&X-Amz-Signature=e135dc0b8f0d7774019d6525aad528b99160aff4007b2a539330c090c29ef9b5&X-Amz-SignedHeaders=host&actor_id=0&repo_id=286301447&response-content-disposition=attachment%3B%20filename%3Dtests.gz&response-content-type=application%2Foctet-stream [following]\n",
            "--2020-08-25 01:56:04--  https://github-production-release-asset-2e65be.s3.amazonaws.com/286301447/080d8800-e653-11ea-8d02-c0c858a09e7a?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20200825%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20200825T015604Z&X-Amz-Expires=300&X-Amz-Signature=e135dc0b8f0d7774019d6525aad528b99160aff4007b2a539330c090c29ef9b5&X-Amz-SignedHeaders=host&actor_id=0&repo_id=286301447&response-content-disposition=attachment%3B%20filename%3Dtests.gz&response-content-type=application%2Foctet-stream\n",
            "Resolving github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)... 52.216.110.243\n",
            "Connecting to github-production-release-asset-2e65be.s3.amazonaws.com (github-production-release-asset-2e65be.s3.amazonaws.com)|52.216.110.243|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 2456199 (2.3M) [application/octet-stream]\n",
            "Saving to: ‘tests.gz’\n",
            "\n",
            "tests.gz            100%[===================>]   2.34M  --.-KB/s    in 0.09s   \n",
            "\n",
            "2020-08-25 01:56:05 (24.7 MB/s) - ‘tests.gz’ saved [2456199/2456199]\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bzdaJiZYBIHE",
        "colab_type": "text"
      },
      "source": [
        "# Build Word Vectors\n",
        "\n",
        "This example builds a search system backed by word embeddings. While not quite as powerful as transformer embeddings, word embeddings often provide a good tradeoff between performance and functionality for an embeddings-based search system.\n",
        "\n",
        "For this notebook, we'll build our own custom embeddings for demo purposes. Alternatively, a number of pre-trained word embedding models are available:\n",
        "\n",
        " - [General language models from pymagnitude](https://github.com/plasticityai/magnitude)\n",
        " - [CORD-19 fastText](https://www.kaggle.com/davidmezzetti/cord19-fasttext-vectors)"
      ]
    },
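    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "fastText-style training reads a plain-text file with one document per line and tokens separated by spaces, which is the format the next cell streams to a temporary file. The cell below is a small sketch of that format using made-up sentences; `write_corpus` and the whitespace tokenizer `lower_split` are hypothetical stand-ins for the notebook code and txtai's `Tokenizer`. It also shows why `str.join` must receive a list of tokens rather than the raw string: joining a string puts a space between every character."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code"
      },
      "source": [
        "import os\n",
        "import tempfile\n",
        "\n",
        "def lower_split(text):\n",
        "  # Hypothetical whitespace tokenizer standing in for txtai's Tokenizer.tokenize\n",
        "  return text.lower().split()\n",
        "\n",
        "def write_corpus(documents, tokenize):\n",
        "  \"\"\"Write one line of space-separated tokens per document and return the file path.\"\"\"\n",
        "  with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".txt\", delete=False) as output:\n",
        "    for text in documents:\n",
        "      # Join the token list; joining the raw string would separate every character\n",
        "      output.write(\" \".join(tokenize(text)) + \"\\n\")\n",
        "  return output.name\n",
        "\n",
        "path = write_corpus([\"Diabetes is a risk factor\", \"Patients were hospitalized\"], lower_split)\n",
        "with open(path) as handle:\n",
        "  lines = handle.read().splitlines()\n",
        "os.remove(path)\n",
        "\n",
        "print(lines)\n",
        "print(\" \".join(\"risk\"))"
      ],
      "execution_count": null,
      "outputs": []
    },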
    {
      "cell_type": "code",
      "metadata": {
        "id": "fJcn-CAH-u3K",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 156
        },
        "outputId": "8ad62b92-55b7-4f2e-af55-f313402574a8"
      },
      "source": [
        "import os\n",
        "import sqlite3\n",
        "import tempfile\n",
        "\n",
        "from txtai.tokenizer import Tokenizer\n",
        "from txtai.vectors import WordVectors\n",
        "\n",
        "print(\"Streaming tokens to temporary file\")\n",
        "\n",
        "# Stream tokens to temp working file\n",
        "with tempfile.NamedTemporaryFile(mode=\"w\", suffix=\".txt\", delete=False) as output:\n",
        "  # Save file path\n",
        "  tokens = output.name\n",
        "\n",
        "  db = sqlite3.connect(\"articles.sqlite\")\n",
        "  cur = db.cursor()\n",
        "  cur.execute(\"SELECT Text from sections\")\n",
        "\n",
        "  for row in cur:\n",
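        "    # Write one line of space-separated tokens per section (join the token list, not the raw string)\n",
        "    output.write(\" \".join(Tokenizer.tokenize(row[0])) + \"\\n\")\n",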
        "\n",
        "  # Free database resources\n",
        "  db.close()\n",
        "\n",
        "# Build word vectors model - 300 dimensions, 3 min occurrences\n",
        "WordVectors.build(tokens, 300, 3, \"cord19-300d\")\n",
        "\n",
        "# Remove temporary tokens file\n",
        "os.remove(tokens)\n",
        "\n",
        "# Show files\n",
        "!ls -l"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Streaming tokens to temporary file\n",
            "Building 300 dimension model\n",
            "Converting vectors to magnitude format\n",
            "total 9024\n",
            "-rw-r--r-- 1 root root 8065024 Aug 25 01:44 articles.sqlite\n",
            "-rw-r--r-- 1 root root  360448 Aug 25 01:57 cord19-300d.magnitude\n",
            "-rw-r--r-- 1 root root  807886 Aug 25 01:57 cord19-300d.txt\n",
            "drwxr-xr-x 1 root root    4096 Jul 30 16:30 sample_data\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_UxcC1-JGH-d",
        "colab_type": "text"
      },
      "source": [
        "# Build an embeddings index\n",
        "\n",
        "The following step builds an embeddings index using the word vector model just created. It builds a BM25 + fastText index: BM25 term weights are used to compute a weighted average of the word embeddings for each section. More information on this method can be found in this [Medium article](https://towardsdatascience.com/building-a-sentence-embedding-index-with-fasttext-and-bm25-f07e7148d240?gi=79da927aa10)."
      ]
    },
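    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the weighting concrete, the cell below is a minimal sketch of a BM25-weighted average of word vectors over a made-up three-document corpus with 2-dimensional toy vectors. It uses the textbook Okapi BM25 formula with `k1=1.2`, `b=0.75` and a non-negative IDF variant; these constants, the IDF form and the choice to weight each distinct term once are assumptions, and txtai's internal implementation may differ. Rarer terms receive higher weights, so they pull the section embedding further toward their own vectors."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code"
      },
      "source": [
        "import math\n",
        "from collections import Counter\n",
        "\n",
        "def bm25_weights(tokens, doc_freq, ndocs, avgdl, k1=1.2, b=0.75):\n",
        "  \"\"\"Okapi BM25 weight for each distinct token in a single document.\"\"\"\n",
        "  weights = {}\n",
        "  for term, freq in Counter(tokens).items():\n",
        "    idf = math.log(1 + (ndocs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5))\n",
        "    weights[term] = idf * freq * (k1 + 1) / (freq + k1 * (1 - b + b * len(tokens) / avgdl))\n",
        "  return weights\n",
        "\n",
        "def weighted_average(vectors, weights):\n",
        "  \"\"\"Average the word vectors of known terms, each scaled by its BM25 weight.\"\"\"\n",
        "  dim = len(next(iter(vectors.values())))\n",
        "  total, norm = [0.0] * dim, 0.0\n",
        "  for term, weight in weights.items():\n",
        "    if term in vectors:\n",
        "      total = [t + weight * v for t, v in zip(total, vectors[term])]\n",
        "      norm += weight\n",
        "  return [t / norm for t in total] if norm else total\n",
        "\n",
        "# Made-up tokenized corpus and toy 2-d word vectors\n",
        "docs = [[\"risk\", \"factors\", \"diabetes\"], [\"hospital\", \"patients\"], [\"risk\", \"patients\"]]\n",
        "doc_freq = Counter(term for doc in docs for term in set(doc))\n",
        "avgdl = sum(len(doc) for doc in docs) / len(docs)\n",
        "vectors = {\"risk\": [1.0, 0.0], \"factors\": [0.0, 1.0], \"diabetes\": [1.0, 1.0],\n",
        "           \"hospital\": [0.5, 0.5], \"patients\": [0.0, 0.5]}\n",
        "\n",
        "weights = bm25_weights(docs[0], doc_freq, len(docs), avgdl)\n",
        "embedding = weighted_average(vectors, weights)\n",
        "print(weights)\n",
        "print(embedding)"
      ],
      "execution_count": null,
      "outputs": []
    },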
    {
      "cell_type": "code",
      "metadata": {
        "id": "5PrrxGRPGHqX",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 52
        },
        "outputId": "36c78d6a-1f01-4e66-d0a8-7d27d3698ab3"
      },
      "source": [
        "import sqlite3\n",
        "\n",
        "import regex as re\n",
        "\n",
        "from txtai.embeddings import Embeddings\n",
        "from txtai.tokenizer import Tokenizer\n",
        "\n",
        "def stream():\n",
        "  # Connection to database file\n",
        "  db = sqlite3.connect(\"articles.sqlite\")\n",
        "  cur = db.cursor()\n",
        "\n",
        "  # Select tagged sentences without an NLP label. NLP labels are set for non-informative sentences.\n",
        "  cur.execute(\"SELECT Id, Name, Text FROM sections WHERE (labels is null or labels NOT IN ('FRAGMENT', 'QUESTION')) AND tags is not null\")\n",
        "\n",
        "  count = 0\n",
        "  for row in cur:\n",
        "    # Unpack row\n",
        "    uid, name, text = row\n",
        "\n",
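        "    # Only process certain document sections: skip background, introduction, reference and discussion\n",
        "    # sections, but keep names like 'Results and Discussion' (variable-length lookbehind requires the regex module)\n",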
        "    if not name or not re.search(r\"background|(?<!.*?results.*?)discussion|introduction|reference\", name.lower()):\n",
        "      # Tokenize text\n",
        "      tokens = Tokenizer.tokenize(text)\n",
        "\n",
        "      document = (uid, tokens, None)\n",
        "\n",
        "      count += 1\n",
        "      if count % 1000 == 0:\n",
        "        print(\"Streamed %d documents\" % (count), end=\"\\r\")\n",
        "\n",
        "      # Skip documents with no tokens parsed\n",
        "      if tokens:\n",
        "        yield document\n",
        "\n",
        "  print(\"Iterated over %d total rows\" % (count))\n",
        "\n",
        "  # Free database resources\n",
        "  db.close()\n",
        "\n",
        "# BM25 + fastText vectors\n",
        "embeddings = Embeddings({\"path\": \"cord19-300d.magnitude\",\n",
        "                         \"scoring\": \"bm25\",\n",
        "                         \"pca\": 3})\n",
        "\n",
        "# Build scoring index if scoring method provided\n",
        "if embeddings.config.get(\"scoring\"):\n",
        "  embeddings.score(stream())\n",
        "\n",
        "# Build embeddings index\n",
        "embeddings.index(stream())\n"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Iterated over 21499 total rows\n",
            "Iterated over 21499 total rows\n"
          ],
          "name": "stdout"
        }
      ]
    },
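    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The section filter in `stream()` relies on the third-party `regex` module (imported as `re`) because `(?<!.*?results.*?)` is a variable-length lookbehind, which Python's built-in `re` rejects with a \"look-behind requires fixed-width pattern\" error. If that dependency is unwanted, the cell below sketches one way to express the same rule with plain string operations; `include_section` is a hypothetical helper and assumes section names are single-line strings. A name is skipped when it contains background, introduction or reference, or when it contains discussion without results appearing earlier in the name, so \"Results and Discussion\" is kept."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code"
      },
      "source": [
        "def include_section(name):\n",
        "  \"\"\"Plain-string version of the section-name filter used in stream() (sketch).\"\"\"\n",
        "  if not name:\n",
        "    return True\n",
        "\n",
        "  name = name.lower()\n",
        "  if any(term in name for term in (\"background\", \"introduction\", \"reference\")):\n",
        "    return False\n",
        "\n",
        "  # Only the first \"discussion\" matters: if \"results\" precedes it, it precedes any later one too\n",
        "  index = name.find(\"discussion\")\n",
        "  return not (index >= 0 and \"results\" not in name[:index])\n",
        "\n",
        "for section in [None, \"Methods\", \"Introduction\", \"Discussion\", \"Results and Discussion\", \"References\"]:\n",
        "  print(section, include_section(section))"
      ],
      "execution_count": null,
      "outputs": []
    },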
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zHk24su3e_gb",
        "colab_type": "text"
      },
      "source": [
        "# Query data\n",
        "\n",
        "The following runs a query against the embeddings index for the terms \"risk factors\". It finds the top 5 matches and returns the corresponding documents associated with each match."
      ]
    },
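    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Conceptually, the query is converted to a vector with the same weighting used for indexing, compared against every indexed section, and the highest-scoring section ids are returned. The cell below is a brute-force sketch of that idea using cosine similarity over a made-up index of 2-dimensional vectors; `search_vectors` is a hypothetical helper. txtai delegates this step to a vector index backend rather than a Python loop, so treat this as a simplified model, not the library's implementation."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code"
      },
      "source": [
        "import heapq\n",
        "import math\n",
        "\n",
        "def normalize(vector):\n",
        "  \"\"\"Scale a vector to unit length so dot products equal cosine similarity.\"\"\"\n",
        "  length = math.sqrt(sum(v * v for v in vector))\n",
        "  return [v / length for v in vector] if length else vector\n",
        "\n",
        "def search_vectors(query, index, limit):\n",
        "  \"\"\"Return the limit best (uid, score) pairs by cosine similarity, best first.\"\"\"\n",
        "  query = normalize(query)\n",
        "  scores = ((uid, sum(q * v for q, v in zip(query, normalize(vector)))) for uid, vector in index.items())\n",
        "  return heapq.nlargest(limit, scores, key=lambda pair: pair[1])\n",
        "\n",
        "# Made-up index of section id -> embedding\n",
        "toy_index = {\"a\": [1.0, 0.0], \"b\": [0.6, 0.8], \"c\": [0.0, 1.0]}\n",
        "matches = search_vectors([1.0, 0.2], toy_index, 2)\n",
        "print(matches)"
      ],
      "execution_count": null,
      "outputs": []
    },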
    {
      "cell_type": "code",
      "metadata": {
        "id": "CRbDhvvDKEl-",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 293
        },
        "outputId": "3f52b2b9-5307-47af-8dee-fdd9e694250b"
      },
      "source": [
        "import pandas as pd\n",
        "\n",
        "from IPython.display import display, HTML\n",
        "\n",
        "pd.set_option(\"display.max_colwidth\", None)\n",
        "\n",
        "db = sqlite3.connect(\"articles.sqlite\")\n",
        "cur = db.cursor()\n",
        "\n",
        "results = []\n",
        "for uid, score in embeddings.search(\"risk factors\", 5):\n",
        "  cur.execute(\"SELECT article, text FROM sections WHERE id = ?\", [uid])\n",
        "  article_id, text = cur.fetchone()\n",
        "\n",
        "  cur.execute(\"SELECT Title, Published, Reference from articles where id = ?\", [article_id])\n",
        "  results.append(cur.fetchone() + (text,))\n",
        "\n",
        "# Free database resources\n",
        "db.close()\n",
        "\n",
        "df = pd.DataFrame(results, columns=[\"Title\", \"Published\", \"Reference\", \"Match\"])\n",
        "\n",
        "display(HTML(df.to_html(index=False)))"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th>Title</th>\n",
              "      <th>Published</th>\n",
              "      <th>Reference</th>\n",
              "      <th>Match</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>Prevalence and Impact of Myocardial Injury in Patients Hospitalized with COVID-19 Infection</td>\n",
              "      <td>2020-04-24 00:00:00</td>\n",
              "      <td>http://medrxiv.org/cgi/content/short/2020.04.20.20072702v1?rss=1</td>\n",
              "      <td>This risk was consistent across patients stratified by history of CVD, risk factors but no CVD, and neither CVD nor risk factors.</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>COVID-19 and associations with frailty and multimorbidity: a prospective analysis of UK Biobank participants</td>\n",
              "      <td>2020-07-23 00:00:00</td>\n",
              "      <td>https://www.ncbi.nlm.nih.gov/pubmed/32705587/</td>\n",
              "      <td>The identification of risk factors for contracting COVID-19 is crucial, to inform public health policy and to facilitate the appropriate distribution of healthcare resources.</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>Quantitative evaluation of olfactory dysfunction in hospitalized patients with Coronavirus [2] (COVID-19)</td>\n",
              "      <td>2020-05-25 00:00:00</td>\n",
              "      <td>https://www.ncbi.nlm.nih.gov/pubmed/32451613/</td>\n",
              "      <td>In addition, these reports included patients with minor COVID-19 symptoms and low-risk factor burden.</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>COVID-19 from the perspective of urban and rural general adult mental health services</td>\n",
              "      <td>2020-05-21 00:00:00</td>\n",
              "      <td>https://doi.org/10.1017/ipm.2020.62</td>\n",
              "      <td>At-risk groups among staff members and service users were identified early and prioritised in service changes.</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>Management of osteoarthritis during COVID‐19 pandemic</td>\n",
              "      <td>2020-05-21 00:00:00</td>\n",
              "      <td>https://doi.org/10.1002/cpt.1910</td>\n",
              "      <td>Consistently, a recent report indicated diabetes as a risk factor significantly associated with COVID-19 unfavourable clinical outcomes (37) .</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XSf68I-ZfXOG",
        "colab_type": "text"
      },
      "source": [
        "# Extracting additional columns from query results\n",
        "\n",
        "The example above uses the Embeddings index to find the top 5 matches. The following cells also use an Extractor instance to ask additional questions over the search results, creating a richer query response."
      ]
    },
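    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Each tuple passed to the extractor below contains a column name, a query, a question and a boolean flag (understood here to control whether a longer snippet is returned instead of the answer span; treat that reading as an assumption). Conceptually, the query selects the most relevant sections of the article as context, and the QA model then answers the question over that context. The cell below sketches only the context-selection step: `select_context` is a hypothetical helper, and a simple word-overlap score stands in for the embeddings-based similarity the extractor is assumed to use, so this is not the library's implementation."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code"
      },
      "source": [
        "def words(text):\n",
        "  \"\"\"Lowercased words with surrounding punctuation stripped.\"\"\"\n",
        "  return {word.strip(\".,;:!?()\") for word in text.lower().split()}\n",
        "\n",
        "def select_context(sections, query, topn=3):\n",
        "  \"\"\"Join the text of the topn (id, text) sections with the most query-word overlap.\"\"\"\n",
        "  terms = words(query)\n",
        "  scored = [(len(terms & words(text)) / max(len(terms), 1), uid, text) for uid, text in sections]\n",
        "\n",
        "  # Highest overlap first; sorted() is stable, so ties keep their original order\n",
        "  ranked = sorted(scored, key=lambda item: item[0], reverse=True)\n",
        "  return \" \".join(text for _, _, text in ranked[:topn])\n",
        "\n",
        "# Made-up article sections as (id, text) pairs, the same shape built in the query cell further below\n",
        "toy_sections = [(1, \"Diabetes is a risk factor for severe outcomes.\"),\n",
        "                (2, \"Patients were treated in New York City hospitals.\"),\n",
        "                (3, \"Age and obesity are risk factors.\")]\n",
        "context = select_context(toy_sections, \"risk factors\", topn=2)\n",
        "print(context)"
      ],
      "execution_count": null,
      "outputs": []
    },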
    {
      "cell_type": "code",
      "metadata": {
        "id": "TLVOTQJchvTi",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "%%capture\n",
        "from txtai.extractor import Extractor\n",
        "\n",
        "# Create extractor instance using a QA model designed for the CORD-19 dataset\n",
        "extractor = Extractor(embeddings, \"NeuML/bert-small-cord19qa\")"
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "19fmKawThs6d",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 293
        },
        "outputId": "e5c24448-97b4-444f-d47e-d268ef18aec5"
      },
      "source": [
        "db = sqlite3.connect(\"articles.sqlite\")\n",
        "cur = db.cursor()\n",
        "\n",
        "results = []\n",
        "for uid, score in embeddings.search(\"risk factors\", 5):\n",
        "  cur.execute(\"SELECT article, text FROM sections WHERE id = ?\", [uid])\n",
        "  article_id, text = cur.fetchone()\n",
        "\n",
        "  # Get list of document text sections to use for the context\n",
        "  cur.execute(\"SELECT Id, Name, Text FROM sections WHERE (labels is null or labels NOT IN ('FRAGMENT', 'QUESTION')) AND article = ?\", [article_id])\n",
        "  sections = []\n",
        "  for sid, name, txt in cur.fetchall():\n",
        "    if not name or not re.search(r\"background|(?<!.*?results.*?)discussion|introduction|reference\", name.lower()):\n",
        "      sections.append((sid, txt))\n",
        "\n",
        "  cur.execute(\"SELECT Title, Published, Reference from articles where id = ?\", [article_id])\n",
        "  article = cur.fetchone()\n",
        "\n",
        "  # Use QA extractor to derive additional columns\n",
        "  answers = extractor(sections, [(\"Risk Factors\", \"risk factors\", \"What risk factors?\", False),\n",
        "                                 (\"Locations\", \"hospital country\", \"What locations?\", False)])\n",
        "\n",
        "  results.append(article + (text,) + tuple([answer[1] for answer in answers]))\n",
        "\n",
        "# Free database resources\n",
        "db.close()\n",
        "\n",
        "df = pd.DataFrame(results, columns=[\"Title\", \"Published\", \"Reference\", \"Match\", \"Risk Factors\", \"Locations\"])\n",
        "display(HTML(df.to_html(index=False)))"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th>Title</th>\n",
              "      <th>Published</th>\n",
              "      <th>Reference</th>\n",
              "      <th>Match</th>\n",
              "      <th>Risk Factors</th>\n",
              "      <th>Locations</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>Prevalence and Impact of Myocardial Injury in Patients Hospitalized with COVID-19 Infection</td>\n",
              "      <td>2020-04-24 00:00:00</td>\n",
              "      <td>http://medrxiv.org/cgi/content/short/2020.04.20.20072702v1?rss=1</td>\n",
              "      <td>This risk was consistent across patients stratified by history of CVD, risk factors but no CVD, and neither CVD nor risk factors.</td>\n",
              "      <td>neither CVD nor risk factors</td>\n",
              "      <td>New York City hospitals</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>COVID-19 and associations with frailty and multimorbidity: a prospective analysis of UK Biobank participants</td>\n",
              "      <td>2020-07-23 00:00:00</td>\n",
              "      <td>https://www.ncbi.nlm.nih.gov/pubmed/32705587/</td>\n",
              "      <td>The identification of risk factors for contracting COVID-19 is crucial, to inform public health policy and to facilitate the appropriate distribution of healthcare resources.</td>\n",
              "      <td>Frailty and multimorbidity</td>\n",
              "      <td>hospital settings</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>Quantitative evaluation of olfactory dysfunction in hospitalized patients with Coronavirus [2] (COVID-19)</td>\n",
              "      <td>2020-05-25 00:00:00</td>\n",
              "      <td>https://www.ncbi.nlm.nih.gov/pubmed/32451613/</td>\n",
              "      <td>In addition, these reports included patients with minor COVID-19 symptoms and low-risk factor burden.</td>\n",
              "      <td>patients with minor COVID-19 symptoms and low-risk factor burden</td>\n",
              "      <td>COVID-19 wards</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>COVID-19 from the perspective of urban and rural general adult mental health services</td>\n",
              "      <td>2020-05-21 00:00:00</td>\n",
              "      <td>https://doi.org/10.1017/ipm.2020.62</td>\n",
              "      <td>At-risk groups among staff members and service users were identified early and prioritised in service changes.</td>\n",
              "      <td>At-risk groups among staff members and service users</td>\n",
              "      <td>rural regions</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>Management of osteoarthritis during COVID‐19 pandemic</td>\n",
              "      <td>2020-05-21 00:00:00</td>\n",
              "      <td>https://doi.org/10.1002/cpt.1910</td>\n",
              "      <td>Consistently, a recent report indicated diabetes as a risk factor significantly associated with COVID-19 unfavourable clinical outcomes (37) .</td>\n",
              "      <td>sex, obesity, genetic factors and mechanical factors</td>\n",
              "      <td>None</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ColTLy--rWfR",
        "colab_type": "text"
      },
      "source": [
        "In the example above, the Embeddings index is used to find the top N results for a given query. On top of that, a question-answer extractor is used to derive additional columns based on a list of questions. In this case, the \"Risk Factors\" and \"Locations\" columns were pulled from the document text."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KWyoysauy7Pr",
        "colab_type": "text"
      },
      "source": [
        "# Next\n",
        "In part 4 of this series, we'll combine the power of Elasticsearch with Extractive QA to build a large-scale, advanced search system.\n"
      ]
    }
  ]
}